// ModelParameters.java
//
// (c) 1999-2001 PAL Development Core Team
//
// This package may be distributed under the
// terms of the Lesser GNU General Public License (LGPL)


package pal.eval;

import pal.alignment.*;
import pal.substmodel.*;
import pal.tree.*;
import pal.distance.*;
import pal.math.*;
import pal.misc.*;

import java.io.*;

/**
 * estimates substitution model parameters from the data
 *
 * The object itself is the objective function handed to the minimizer:
 * {@link #evaluate(double[])} pushes a parameter vector into the model and
 * returns the negated value computed by the likelihood object, so that
 * minimizing this function maximizes that value.
 *
 * @version $Id: ModelParameters.java,v 1.7 2001/07/13 14:39:13 korbinian Exp $
 * @author Korbinian Strimmer
 */
public class ModelParameters implements MultivariateFunction
{
	//
	// public stuff
	//
	
	/** fractional digits desired for parameters */
	public final static int FRACDIGITS = 3;
	
	/**
	 * Constructor
	 *
	 * A one-parameter model is optimized with an orthogonal search,
	 * any other model with a conjugate gradient search.
	 *
	 * @param sp site pattern
	 * @param m substitution model
	 */
	public ModelParameters(SitePattern sp, SubstitutionModel m)
	{
		sitePattern = sp;
		model = m;
		numParams = model.getNumParameters();
		
		lv = new LikelihoodValue(sitePattern);
		lv.setModel(model);

		if (numParams == 1)
		{
			mvm = new OrthogonalSearch();
		}
		else
		{
			mvm = new ConjugateGradientSearch();
		}
	}
	
	/**
	 * estimate (approximate) values for the model parameters
	 * from the data using a neighbor-joining tree
	 *
	 * The tree is rebuilt in every round from a distance matrix that is
	 * computed with the current state of the model, and the parameters
	 * are then re-optimized on that tree until the minimizer's stop
	 * condition is met. The estimates are rounded to FRACDIGITS
	 * fractional digits and their standard errors are stored in the model.
	 *
	 * @return parameter estimates
	 */
	public double[] estimate()
	{
		double[] p = defaultParameters();
		
		AlignmentDistanceMatrix distMat = null;
		double fp;
		do
		{
			// The first matrix is built from the model as it currently
			// stands, i.e. before any parameter vector has been pushed
			// into the model by evaluate().
			boolean first = (distMat == null);
			if (first)
			{
				distMat = new AlignmentDistanceMatrix(sitePattern, model);
			}
			else
			{
				distMat.recompute(model);
			}
			
			NeighborJoiningTree t = new NeighborJoiningTree(distMat);
			ParameterizedTree pt = new UnconstrainedTree(t);
			lv.setTree(pt);
			
			fp = optimizeStep(p, first);
		}
		while (!mvm.stopCondition(fp, p, TOLERANCE, TOLERANCE, false));
		
		return finishEstimate(p);
	}

	/**
	 * estimate (approximate) values for the model parameters
	 * from the data using a given (parameterized) tree
	 *
	 * The estimates are rounded to FRACDIGITS fractional digits and
	 * their standard errors are stored in the model.
	 *
	 * @param t tree on which the likelihood is computed
	 * @return parameter estimates
	 */
	public double[] estimateFromTree(ParameterizedTree t)
	{
		double[] p = defaultParameters();
		
		boolean first = true;
		double fp;
		do
		{
			// the tree is (re)assigned in every round, as it always was
			lv.setTree(t);
			
			fp = optimizeStep(p, first);
			first = false;
		}
		while (!mvm.stopCondition(fp, p, TOLERANCE, TOLERANCE, false));
		
		return finishEstimate(p);
	}

	
	// interface MultivariateFunction

	/**
	 * objective function: sets the model parameters and returns the
	 * negated value computed by the likelihood object
	 *
	 * @param params parameter values (the first getNumArguments()
	 *               entries are used)
	 * @return negated result of the likelihood computation
	 */
	public double evaluate(double[] params)
	{
		// set model parameters
		for (int i = 0; i < numParams; i++)
		{
			model.setParameter(params[i], i);
		}

		return -lv.compute();
	}
	
	/**
	 * @return number of model parameters
	 */
	public int getNumArguments()
	{
		return numParams;
	}
	
	/**
	 * @param n parameter index
	 * @return lower limit of the n-th model parameter
	 */
	public double getLowerBound(int n)
	{
		return model.getLowerLimit(n);
	}

	/**
	 * @param n parameter index
	 * @return upper limit of the n-th model parameter
	 */
	public double getUpperBound(int n)
	{
		return model.getUpperLimit(n);
	}


	//
	// Private stuff
	//
	
	/**
	 * tolerance (one digit beyond FRACDIGITS) used both for the
	 * function value and for the parameters
	 */
	private static final double TOLERANCE = Math.pow(10, -1-FRACDIGITS);
	
	/** creates the start vector from the default values of the model */
	private double[] defaultParameters()
	{
		double[] p = new double[numParams];
		for (int i = 0; i < numParams; i++)
		{
			p[i] = model.getDefaultValue(i);
		}
		return p;
	}
	
	/**
	 * runs one optimization round on the tree currently set in the
	 * likelihood object; in the first round the stop condition of the
	 * minimizer is initialized with the start point beforehand
	 *
	 * @return function value at the optimized point
	 */
	private double optimizeStep(double[] p, boolean first)
	{
		if (first)
		{
			double start = evaluate(p);
			mvm.stopCondition(start, p, TOLERANCE, TOLERANCE, true);
		}
		
		mvm.optimize(this, p, TOLERANCE, TOLERANCE);
		
		return evaluate(p);
	}
	
	/**
	 * rounds the estimates to FRACDIGITS fractional digits, pushes the
	 * rounded values into the model and stores, for each parameter,
	 * sqrt(1/h) as standard error in the model, where h is the
	 * corresponding diagonal element of the numerical Hessian of
	 * this function
	 *
	 * @return p (rounded in place)
	 */
	private double[] finishEstimate(double[] p)
	{
		// trim p
		double scale = Math.pow(10, FRACDIGITS);
		for (int i = 0; i < p.length; i++)
		{
			p[i] = Math.round(p[i]*scale)/scale;
		}
		
		// called for its side effect: the model now holds the rounded values
		evaluate(p);
		
		// Corresponding SEs
		// NOTE(review): diagonalHessian evaluates this function itself, so
		// the model may be left at whatever point it evaluated last -
		// confirm that this equals p if callers read the model afterwards.
		double[] hessian = NumericalDerivative.diagonalHessian(this, p);
		for (int i = 0; i < numParams; i++)
		{
			model.setParameterSE(Math.sqrt(1.0/hessian[i]), i);
		}
		return p;
	}
	
	private final int numParams;
	private final SubstitutionModel model;
	private final SitePattern sitePattern;
	private final LikelihoodValue lv;
	private final MultivariateMinimum mvm;
}

